{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This demo provides examples of `ImageReader` class from `niftynet.io.image_reader` module.\n",
    "\n",
    "What is `ImageReader`?\n",
    "\n",
    "The main functionality of `ImageReader` is to search a set of folders, return a list of image files, and load the images into memory in an iterative manner.\n",
    "\n",
    "A `tf.data.Dataset` instance can be initialised from an `ImageReader`, this makes the module readily usable as an input op to many tensorflow-based applications.\n",
    "\n",
    "Why `ImageReader`?\n",
    "\n",
    " - designed for medical imaging formats and applications \n",
    " - works well with multi-modal input volumes\n",
    " - works well with `tf.data.Dataset`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Before the demo..."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First make sure the source code is available, and import the module.\n",
    "\n",
    "For NiftyNet installation, please checkout:\n",
    "\n",
    "http://niftynet.readthedocs.io/en/dev/installation.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "niftynet_path = '/Users/bar/Documents/Niftynet/'\n",
    "sys.path.insert(0, niftynet_path)\n",
    "\n",
    "from niftynet.io.image_reader import ImageReader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For demonstration purpose we download some demo data to `~/niftynet/data/`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accessing: https://github.com/NifTK/NiftyNetModelZoo\n",
      "anisotropic_nets_brats_challenge_model_zoo: OK. \n",
      "Already downloaded. Use the -r option to download again.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from niftynet.utilities.download import download\n",
    "download('anisotropic_nets_brats_challenge_model_zoo')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use case: loading 3D volumes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from niftynet.io.image_reader import ImageReader\n",
    "\n",
    "data_param = {'MR': {'path_to_search': '~/niftynet/data/BRATS_examples/HGG'}}\n",
    "reader = ImageReader().initialise(data_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reader.shapes, reader.tf_dtypes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read data using the initialised reader\n",
    "idx, image_data, interp_order = reader(idx=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_data['MR'].shape, image_data['MR'].dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# randomly sample the list of images\n",
    "for _ in range(3):\n",
    "    idx, image_data, _ = reader()\n",
    "    print('{} image: {}'.format(idx, image_data['MR'].shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The images are always read into a 5D-array, representing:\n",
    "\n",
    "`[height, width, depth, time, channels]`"
   ]
  },
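  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of this layout (using NumPy only, with a small hypothetical volume rather than the demo data), a 3D array can be viewed as 5D by appending singleton `time` and `channels` axes. The array sizes below are assumptions chosen for illustration:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# hypothetical 3D volume: [height, width, depth]\n",
    "volume_3d = np.zeros((4, 5, 6), dtype=np.float32)\n",
    "\n",
    "# append singleton time and channel axes to obtain the 5D layout\n",
    "volume_5d = volume_3d[..., np.newaxis, np.newaxis]\n",
    "print(volume_5d.shape)  # (4, 5, 6, 1, 1)\n",
    "\n",
    "# a 2D slice at depth index 3, first time point, first channel\n",
    "slice_2d = volume_5d[:, :, 3, 0, 0]\n",
    "print(slice_2d.shape)  # (4, 5)\n",
    "```\n",
    "\n",
    "The same indexing pattern applies to arrays returned by the reader, e.g. `image_data['MR'][:, :, 50, 0, 0]` selects one 2D slice."
   ]
  },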
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## User case: loading pairs of image and label by matching filenames\n",
    "(In this case the loaded arrays are not concatenated.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from niftynet.io.image_reader import ImageReader\n",
    "\n",
    "data_param = {'image': {'path_to_search': '~/niftynet/data/BRATS_examples/HGG',\n",
    "                        'filename_contains': 'T2'},\n",
    "              'label': {'path_to_search': '~/niftynet/data/BRATS_examples/HGG',\n",
    "                        'filename_contains': 'Label'}}\n",
    "reader = ImageReader().initialise(data_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image file information (without loading the volumes)\n",
    "reader.get_subject(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx, image_data, interp_order = reader(idx=0)\n",
    "\n",
    "image_data['image'].shape, image_data['label'].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## User case: loading multiple modalities of image and label by matching filenames\n",
    "\n",
    "The following code initialises a reader with four modalities, and the `'image'` output is a concatenation of arrays loaded from these files. (The files are concatenated at the fifth dimension)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from niftynet.io.image_reader import ImageReader\n",
    "\n",
    "data_param = {'T1':    {'path_to_search': '~/niftynet/data/BRATS_examples/HGG',\n",
    "                        'filename_contains': 'T1', 'filename_not_contains': 'T1c'},\n",
    "              'T1c':   {'path_to_search': '~/niftynet/data/BRATS_examples/HGG',\n",
    "                        'filename_contains': 'T1c'},\n",
    "              'T2':    {'path_to_search': '~/niftynet/data/BRATS_examples/HGG',\n",
    "                        'filename_contains': 'T2'},\n",
    "              'Flair': {'path_to_search': '~/niftynet/data/BRATS_examples/HGG',\n",
    "                        'filename_contains': 'Flair'},\n",
    "              'label': {'path_to_search': '~/niftynet/data/BRATS_examples/HGG',\n",
    "                        'filename_contains': 'Label'}}\n",
    "grouping_param = {'image': ('T1', 'T1c', 'T2', 'Flair'), 'label':('label',)}\n",
    "reader = ImageReader().initialise(data_param, grouping_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, image_data, _ = reader(idx=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_data['image'].shape, image_data['label'].shape"
   ]
  },
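  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The concatenation along the fifth dimension can be illustrated with a small NumPy sketch. It assumes each modality is loaded as a single-channel 5D array; the array sizes and fill values are hypothetical and are not those of the demo data:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# hypothetical 5D arrays, one per modality:\n",
    "# [height, width, depth, time, channels]\n",
    "order = ('T1', 'T1c', 'T2', 'Flair')\n",
    "shape = (4, 5, 6, 1, 1)\n",
    "modalities = {name: np.full(shape, i, dtype=np.float32)\n",
    "              for i, name in enumerate(order)}\n",
    "\n",
    "# concatenate along the fifth (last) axis, following the grouping order\n",
    "image = np.concatenate([modalities[name] for name in order], axis=-1)\n",
    "print(image.shape)  # (4, 5, 6, 1, 4)\n",
    "```\n",
    "\n",
    "In this sketch, channel `k` of `image` holds the `k`-th modality listed in `order`."
   ]
  },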
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## More properties\n",
    "The input specification supports additional properties include \n",
    "```python\n",
    "{'csv_file', 'path_to_search',\n",
    " 'filename_contains', 'filename_not_contains',\n",
    " 'interp_order', 'pixdim', 'axcodes', 'spatial_window_size',\n",
    " 'loader'}\n",
    "```\n",
    "see also: http://niftynet.readthedocs.io/en/dev/config_spec.html#input-data-source-section"
   ]
  },
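  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the `'T1'` entry in the earlier examples also sets `'filename_not_contains': 'T1c'`, the sketch below mimics the two filename filters with plain substring tests. This is an illustration only: the helper `match` and the file names are hypothetical, and NiftyNet's actual matching code may differ in detail.\n",
    "\n",
    "```python\n",
    "# hypothetical file names, not taken from the demo data\n",
    "filenames = ['subj1_T1.nii.gz', 'subj1_T1c.nii.gz', 'subj1_T2.nii.gz']\n",
    "\n",
    "def match(names, contains, not_contains=None):\n",
    "    # keep names that include `contains` and exclude `not_contains`\n",
    "    return [name for name in names\n",
    "            if contains in name\n",
    "            and (not_contains is None or not_contains not in name)]\n",
    "\n",
    "print(match(filenames, 'T1'))         # ['subj1_T1.nii.gz', 'subj1_T1c.nii.gz']\n",
    "print(match(filenames, 'T1', 'T1c'))  # ['subj1_T1.nii.gz']\n",
    "```\n",
    "\n",
    "Without the exclusion, the substring `'T1'` would also select the `'T1c'` files."
   ]
  },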
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using ImageReader with image-level data augmentation layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from niftynet.io.image_reader import ImageReader\n",
    "from niftynet.layer.rand_rotation import RandomRotationLayer as Rotate\n",
    "\n",
    "data_param = {'MR': {'path_to_search': '~/niftynet/data/BRATS_examples/HGG'}}\n",
    "reader = ImageReader().initialise(data_param)\n",
    "\n",
    "rotation_layer = Rotate()\n",
    "rotation_layer.init_uniform_angle([-10.0, 10.0])\n",
    "reader.add_preprocessing_layers([rotation_layer])\n",
    "\n",
    "_, image_data, _ = reader(idx=0)\n",
    "image_data['MR'].shape\n",
    "\n",
    "# import matplotlib.pyplot as plt\n",
    "# plt.imshow(image_data['MR'][:, :, 50, 0, 0])\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using ImageReader with `tf.data.Dataset`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from niftynet.io.image_reader import ImageReader\n",
    "\n",
    "# initialise multi-modal image and label reader\n",
    "data_param = {'T1':    {'path_to_search': '~/niftynet/data/BRATS_examples/HGG',\n",
    "                        'filename_contains': 'T1', 'filename_not_contains': 'T1c'},\n",
    "              'T1c':   {'path_to_search': '~/niftynet/data/BRATS_examples/HGG',\n",
    "                        'filename_contains': 'T1c'},\n",
    "              'T2':    {'path_to_search': '~/niftynet/data/BRATS_examples/HGG',\n",
    "                        'filename_contains': 'T2'},\n",
    "              'Flair': {'path_to_search': '~/niftynet/data/BRATS_examples/HGG',\n",
    "                        'filename_contains': 'Flair'},\n",
    "              'label': {'path_to_search': '~/niftynet/data/BRATS_examples/HGG',\n",
    "                        'filename_contains': 'Label'}}\n",
    "\n",
    "grouping_param = {'image': ('T1', 'T1c', 'T2', 'Flair'), 'label':('label',)}\n",
    "reader = ImageReader().initialise(data_param, grouping_param)\n",
    "\n",
    "# reader as a generator\n",
    "def image_label_pair_generator():\n",
    "    \"\"\"\n",
    "    A generator wrapper of an initialised reader.\n",
    "    \n",
    "    :yield: a dictionary of images (numpy arrays).\n",
    "    \"\"\"\n",
    "    while True:\n",
    "        _, image_data, _ = reader()\n",
    "        yield image_data\n",
    "\n",
    "# tensorflow dataset\n",
    "dataset = tf.data.Dataset.from_generator(\n",
    "    image_label_pair_generator,\n",
    "    output_types=reader.tf_dtypes)\n",
    "    #output_shapes=reader.shapes)\n",
    "dataset = dataset.batch(1)\n",
    "iterator = dataset.make_initializable_iterator()\n",
    "\n",
    "# run the tensorlfow graph\n",
    "with tf.Session() as sess:\n",
    "    sess.run(iterator.initializer)\n",
    "    for _ in range(3):\n",
    "        data_dict = sess.run(iterator.get_next())\n",
    "        print(data_dict.keys())\n",
    "        print('image: {}, label: {}'.format(\n",
    "            data_dict['image'].shape,\n",
    "            data_dict['label'].shape))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
